.. _array:

Array
=====

.. currentmodule:: mlx.core

.. autosummary:: 
   :toctree: _autosummary 

    array
    array.astype
    array.item
    array.tolist
    array.dtype
    array.ndim
    array.shape
    array.size
    Dtype
    array.abs
    array.all
    array.any
    array.argmax
    array.argmin
    array.cos
    array.exp
    array.log
    array.log1p
    array.logsumexp
    array.max
    array.mean
    array.min
    array.prod
    array.reciprocal
    array.reshape
    array.round
    array.rsqrt
    array.sin
    array.split
    array.sqrt
    array.square
    array.sum
    array.transpose
    array.T
    array.var
.. _data_types:

:orphan:

Data Types
==========

.. currentmodule:: mlx.core

The default floating point type is ``float32`` and the default integer type is
``int32``. The table below shows supported values for :obj:`Dtype`. 

.. list-table:: Supported Data Types 
   :widths: 5 3 20
   :header-rows: 1

   * - Type 
     - Bytes
     - Description
   * - ``bool_``
     - 1 
     - Boolean (``True``, ``False``) data type
   * - ``uint8``
     - 1 
     - 8-bit unsigned integer 
   * - ``uint16``
     - 2 
     - 16-bit unsigned integer 
   * - ``uint32``
     - 4 
     - 32-bit unsigned integer 
   * - ``uint64``
     - 8 
     - 64-bit unsigned integer 
   * - ``int8``
     - 1 
     - 8-bit signed integer 
   * - ``int16``
     - 2 
     - 16-bit signed integer 
   * - ``int32``
     - 4 
     - 32-bit signed integer 
   * - ``int64``
     - 8 
     - 64-bit signed integer 
   * - ``float16``
     - 2 
     - 16-bit float, only available with `ARM C language extensions <https://developer.arm.com/documentation/101028/0012/3--C-language-extensions?lang=en>`_
   * - ``float32``
     - 4 
     - 32-bit float
.. _devices_and_streams:

Devices and Streams
===================

.. currentmodule:: mlx.core

.. autosummary::
  :toctree: _autosummary

   Device
   default_device
   set_default_device
   Stream
   default_stream
   new_stream
   set_default_stream
.. _fft:

FFT
===

.. currentmodule:: mlx.core.fft

.. autosummary:: 
  :toctree: _autosummary

  fft
  ifft
  fft2
  ifft2
  fftn
  ifftn
  rfft
  irfft
  rfft2
  irfft2
  rfftn
  irfftn
.. _linalg:

Linear Algebra
==============

.. currentmodule:: mlx.core.linalg

.. autosummary:: 
   :toctree: _autosummary 

    norm
.. _nn:

.. currentmodule:: mlx.nn

Neural Networks
===============

Writing arbitrarily complex neural networks in MLX can be done using only
:class:`mlx.core.array` and :meth:`mlx.core.value_and_grad`. However, this requires the
user to rewrite the same simple neural network operations again and again, and
to handle all the parameter state and initialization manually and explicitly.

The module :mod:`mlx.nn` solves this problem by providing an intuitive way of
composing neural network layers, initializing their parameters, freezing them
for fine-tuning, and more.

Quick Start with Neural Networks
---------------------------------

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn

    class MLP(nn.Module):
        def __init__(self, in_dims: int, out_dims: int):
            super().__init__()

            self.layers = [
                nn.Linear(in_dims, 128),
                nn.Linear(128, 128),
                nn.Linear(128, out_dims),
            ]

        def __call__(self, x):
            for i, l in enumerate(self.layers):
                x = mx.maximum(x, 0) if i > 0 else x
                x = l(x)
            return x

    # The model is created with all its parameters but nothing is initialized
    # yet because MLX is lazily evaluated
    mlp = MLP(2, 10)

    # We can access its parameters by calling mlp.parameters()
    params = mlp.parameters()
    print(params["layers"][0]["weight"].shape)

    # Printing a parameter will cause it to be evaluated and thus initialized
    print(params["layers"][0])

    # We can also force evaluate all parameters to initialize the model
    mx.eval(mlp.parameters())

    # A simple loss function.
    # NOTE: It doesn't matter how it uses the mlp model. It currently captures
    #       it from the local scope. It could be a positional argument or a
    #       keyword argument.
    def l2_loss(x, y):
        y_hat = mlp(x)
        return (y_hat - y).square().mean()

    # Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the
    # gradient with respect to `mlp.trainable_parameters()`
    loss_and_grad = nn.value_and_grad(mlp, l2_loss)

.. _module_class:

The Module Class
----------------

The workhorse of any neural network library is the :class:`Module` class. In
MLX the :class:`Module` class is a container of :class:`mlx.core.array` or
:class:`Module` instances. Its main function is to provide a way to
recursively **access** and **update** its parameters and those of its
submodules.

Parameters
^^^^^^^^^^

A parameter of a module is any public member of type :class:`mlx.core.array` (its
name should not start with ``_``). It can be arbitrarily nested in other
:class:`Module` instances or lists and dictionaries.

:meth:`Module.parameters` can be used to extract a nested dictionary with all
the parameters of a module and its submodules.

A :class:`Module` can also keep track of "frozen" parameters. See the
:meth:`Module.freeze` method for more details. When using
:meth:`mlx.nn.value_and_grad`, the gradients returned will be with respect to
the trainable (that is, non-frozen) parameters.


Updating the Parameters
^^^^^^^^^^^^^^^^^^^^^^^

MLX modules allow accessing and updating individual parameters. However, most
of the time we need to update large subsets of a module's parameters. This
action is performed by :meth:`Module.update`.


Inspecting Modules
^^^^^^^^^^^^^^^^^^

The simplest way to see the model architecture is to print it. Following along with
the above example, you can print the ``MLP`` with:

.. code-block:: python

  print(mlp)

This will display:

.. code-block:: shell

  MLP(
    (layers.0): Linear(input_dims=2, output_dims=128, bias=True)
    (layers.1): Linear(input_dims=128, output_dims=128, bias=True)
    (layers.2): Linear(input_dims=128, output_dims=10, bias=True)
  )

To get more detailed information on the arrays in a :class:`Module` you can use
:func:`mlx.utils.tree_map` on the parameters. For example, to see the shapes of
all the parameters in a :class:`Module` do:

.. code-block:: python

   from mlx.utils import tree_map
   shapes = tree_map(lambda p: p.shape, mlp.parameters())

As another example, you can count the number of parameters in a :class:`Module`
with:

.. code-block:: python

   from mlx.utils import tree_flatten
   num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))


Value and Grad
--------------

Using a :class:`Module` does not preclude using MLX's higher-order function
transformations (:meth:`mlx.core.value_and_grad`, :meth:`mlx.core.grad`, etc.). However,
these function transformations assume pure functions, namely the parameters
should be passed as an argument to the function being transformed.

There is an easy pattern to achieve that with MLX modules:

.. code-block:: python

    model = ...

    def f(params, other_inputs):
        model.update(params)  # <---- Necessary to make the model use the passed parameters
        return model(other_inputs)

    f(model.trainable_parameters(), mx.zeros((10,)))

However, :meth:`mlx.nn.value_and_grad` provides precisely this pattern and only
computes the gradients with respect to the trainable parameters of the model.

In detail:

- it wraps the passed function with a function that calls :meth:`Module.update`
  to make sure the model is using the provided parameters.
- it calls :meth:`mlx.core.value_and_grad` to transform the function into a function
  that also computes the gradients with respect to the passed parameters.
- it wraps the returned function with a function that passes the trainable
  parameters as the first argument to the function returned by
  :meth:`mlx.core.value_and_grad`.
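
The three steps above can be sketched without MLX. The example below is only an
illustration under stated assumptions, not the library's implementation:
``ToyModel``, ``central_difference`` and ``module_value_and_grad`` are
hypothetical names, plain floats stand in for arrays, and a finite-difference
helper stands in for :meth:`mlx.core.value_and_grad`.

.. code-block:: python

    def central_difference(f, eps=1e-5):
        """Toy stand-in for a value-and-grad transform (an assumption).

        Differentiates ``f`` with respect to its first argument, a flat dict
        of floats, using central finite differences.
        """
        def value_and_grad_fn(params, *args):
            grads = {}
            for name, p in params.items():
                plus = f({**params, name: p + eps}, *args)
                minus = f({**params, name: p - eps}, *args)
                grads[name] = (plus - minus) / (2 * eps)
            # Evaluate at the unperturbed parameters last so the model ends
            # up holding its original values.
            value = f(params, *args)
            return value, grads
        return value_and_grad_fn


    class ToyModel:
        """Hypothetical stand-in for a Module with two scalar parameters."""

        def __init__(self):
            self.w, self.b = 2.0, 0.5

        def trainable_parameters(self):
            return {"w": self.w, "b": self.b}

        def update(self, params):
            self.w, self.b = params["w"], params["b"]

        def __call__(self, x):
            return self.w * x + self.b


    def module_value_and_grad(model, fn, transform=central_difference):
        # Step 1: make the model use whatever parameters are passed in.
        def inner(params, *args):
            model.update(params)
            return fn(*args)

        # Step 2: differentiate with respect to the first argument (params).
        value_and_grad_fn = transform(inner)

        # Step 3: supply the trainable parameters automatically.
        def wrapped(*args):
            return value_and_grad_fn(model.trainable_parameters(), *args)

        return wrapped


    model = ToyModel()

    def l2_loss(x, y):
        return (model(x) - y) ** 2

    loss_and_grad = module_value_and_grad(model, l2_loss)
    value, grads = loss_and_grad(3.0, 5.0)

As in the quick start example, the loss function captures ``model`` from the
enclosing scope and the caller passes only the data, never the parameters.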

.. autosummary::
   :toctree: _autosummary

   value_and_grad

.. toctree::

   nn/module
   nn/layers
   nn/functions
   nn/losses
   nn/init
.. _ops:

Operations
==========

.. currentmodule:: mlx.core

.. autosummary:: 
  :toctree: _autosummary

   abs
   add
   all
   allclose 
   any
   arange
   arccos
   arccosh
   arcsin
   arcsinh
   arctan
   arctanh
   argmax
   argmin
   argpartition
   argsort
   array_equal
   broadcast_to
   ceil
   clip
   concatenate
   convolve
   conv1d
   conv2d
   cos
   cosh
   dequantize
   divide
   divmod
   equal
   erf
   erfinv
   exp
   expand_dims
   eye
   flatten
   floor
   floor_divide
   full
   greater
   greater_equal
   identity
   inner
   isnan
   isposinf
   isneginf
   isinf
   less
   less_equal
   linspace
   load
   log
   log2
   log10
   log1p
   logaddexp
   logical_not
   logical_and
   logical_or
   logsumexp
   matmul
   max
   maximum
   mean
   min
   minimum
   moveaxis
   multiply
   negative
   ones
   ones_like
   outer
   partition
   pad
   prod
   quantize
   quantized_matmul
   reciprocal
   repeat
   reshape
   round
   rsqrt
   save
   savez
   savez_compressed
   save_gguf
   save_safetensors
   sigmoid
   sign
   sin
   sinh
   softmax
   sort
   split
   sqrt
   square
   squeeze
   stack
   stop_gradient
   subtract
   sum
   swapaxes
   take
   take_along_axis
   tan
   tanh
   tensordot
   transpose
   tri
   tril
   triu
   var
   where
   zeros
   zeros_like
.. _optimizers:

Optimizers
==========

The optimizers in MLX can be used both with :mod:`mlx.nn` and with pure
:mod:`mlx.core` functions. A typical example involves calling
:meth:`Optimizer.update` to update a model's parameters based on the loss
gradients and subsequently calling :func:`mlx.core.eval` to evaluate both the
model's parameters and the **optimizer state**.

.. code-block:: python

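    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim

    # NOTE: MLP, loss_fn, batch_iterate, the training data and the
    #       hyperparameters are assumed to be defined elsewhere.

    # Create a model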
    model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
    mx.eval(model.parameters())

    # Create the gradient function and the optimizer
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    optimizer = optim.SGD(learning_rate=learning_rate)

    for e in range(num_epochs):
        for X, y in batch_iterate(batch_size, train_images, train_labels):
            loss, grads = loss_and_grad_fn(model, X, y)

            # Update the model with the gradients. So far no computation has happened.
            optimizer.update(model, grads)

            # Compute the new parameters but also the optimizer state.
            mx.eval(model.parameters(), optimizer.state)
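
To make the role of the update step concrete, the following framework-free
sketch applies one plain gradient descent step (no momentum) to a nested tree
of parameters. It is an illustration only and not how :class:`SGD` is
implemented: ``sgd_step`` is a hypothetical helper, plain floats stand in for
arrays, and no optimizer state is tracked.

.. code-block:: python

    def sgd_step(params, grads, learning_rate):
        """Return a new tree with ``p - learning_rate * g`` at every leaf."""
        if isinstance(params, dict):
            return {
                k: sgd_step(params[k], grads[k], learning_rate) for k in params
            }
        if isinstance(params, (list, tuple)):
            return type(params)(
                sgd_step(p, g, learning_rate) for p, g in zip(params, grads)
            )
        # Leaf: a single parameter and its gradient.
        return params - learning_rate * grads


    # The gradient tree has the same structure as the parameter tree.
    params = {"layers": [{"weight": 1.0, "bias": 0.5}, {"weight": -2.0, "bias": 0.0}]}
    grads = {"layers": [{"weight": 0.5, "bias": 1.0}, {"weight": -1.0, "bias": 0.0}]}

    new_params = sgd_step(params, grads, learning_rate=0.1)

The sketch relies on the gradients mirroring the structure of the parameters,
which is what lets an update walk both trees in lockstep.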

.. currentmodule:: mlx.optimizers

.. autosummary::
   :toctree: _autosummary
   :template: optimizers-template.rst

   OptimizerState
   Optimizer
   SGD
   RMSprop
   Adagrad
   Adafactor
   AdaDelta
   Adam
   AdamW
   Adamax
   Lion
.. _random:

Random
======

Random sampling functions in MLX use an implicit global PRNG state by default.
However, all functions take an optional ``key`` keyword argument for when more
fine-grained control or explicit state management is needed.

For example, you can generate random numbers with:

.. code-block:: python

  import mlx.core as mx

  for _ in range(3):
    print(mx.random.uniform())

which will print a sequence of unique pseudo-random numbers. Alternatively, you
can explicitly set the key:

.. code-block:: python

  key = mx.random.key(0)
  for _ in range(3):
    print(mx.random.uniform(key=key))

which will yield the same pseudo-random number at each iteration.

Following `JAX's PRNG design <https://jax.readthedocs.io/en/latest/jep/263-prng.html>`_
we use a splittable version of Threefry, which is a counter-based PRNG.
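
The difference between reusing a key and splitting it can be illustrated
without MLX. The sketch below is a toy hash-based generator; it is neither
Threefry nor MLX's implementation. The ``uniform`` and ``split`` functions are
hypothetical pure-Python stand-ins that only mimic the behavior described
above, and keys are plain integers here.

.. code-block:: python

    import hashlib
    import struct


    def _hash_to_int(*words):
        """Hash a tuple of 64-bit integers to a new 64-bit integer."""
        data = struct.pack(f">{len(words)}Q", *words)
        return int.from_bytes(hashlib.sha256(data).digest()[:8], "big")


    def uniform(key):
        """Deterministically map a key to a float in [0, 1)."""
        # Keep the top 53 bits so the result is exactly representable.
        return (_hash_to_int(key) >> 11) / 2**53


    def split(key, num=2):
        """Derive ``num`` new keys by hashing (key, counter) pairs."""
        return [_hash_to_int(key, counter) for counter in range(num)]


    key = 0

    # Reusing one key always gives the same sample.
    repeated = [uniform(key) for _ in range(3)]

    # Splitting first gives each draw its own key, hence different samples.
    subkeys = split(key, 3)
    fresh = [uniform(k) for k in subkeys]

Because every sample is a pure function of its key, results are reproducible
without any hidden global state.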

.. currentmodule:: mlx.core.random

.. autosummary:: 
  :toctree: _autosummary

   bernoulli
   categorical
   gumbel
   key
   normal
   randint
   seed
   split
   truncated_normal
   uniform
.. _transforms:

Transforms
==========

.. currentmodule:: mlx.core

.. autosummary::
  :toctree: _autosummary

   eval
   grad
   value_and_grad
   jvp
   vjp
   vmap
   simplify
.. _utils:

Tree Utils
==========

In MLX we consider a Python tree to be an arbitrarily nested collection of
dictionaries, lists and tuples without cycles. Functions in this module that
return Python trees use the default Python ``dict``, ``list`` and ``tuple``,
but they can usually process objects that inherit from any of these.

.. note::
   Dictionaries should have keys that are valid Python identifiers.
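
As a rough illustration of what operating on such a tree means, the following
pure-Python sketch maps a function over the leaves of a tree and flattens a
tree into ``(key, value)`` pairs with dotted keys. The names
``tree_map_sketch`` and ``tree_flatten_sketch`` are hypothetical, and the code
is a simplified approximation rather than the implementation behind the
functions listed below.

.. code-block:: python

    def tree_map_sketch(fn, tree):
        """Apply ``fn`` to every leaf, preserving the container structure."""
        if isinstance(tree, dict):
            return {k: tree_map_sketch(fn, v) for k, v in tree.items()}
        if isinstance(tree, (list, tuple)):
            return type(tree)(tree_map_sketch(fn, v) for v in tree)
        return fn(tree)


    def tree_flatten_sketch(tree, prefix=""):
        """Flatten a tree into a list of (dotted_key, leaf) pairs."""
        if isinstance(tree, dict):
            items = tree.items()
        elif isinstance(tree, (list, tuple)):
            items = enumerate(tree)
        else:
            return [(prefix, tree)]
        flat = []
        for key, value in items:
            path = f"{prefix}.{key}" if prefix else str(key)
            flat.extend(tree_flatten_sketch(value, path))
        return flat


    tree = {
        "layers": [{"weight": 1.0, "bias": 2.0}, {"weight": 3.0}],
        "scale": (4.0, 5.0),
    }

    doubled = tree_map_sketch(lambda leaf: 2 * leaf, tree)
    flat = tree_flatten_sketch(tree)

Here list and tuple positions become numeric components of the key, so the
first layer's weight ends up under ``layers.0.weight``.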

.. currentmodule:: mlx.utils

.. autosummary:: 
  :toctree: _autosummary

   tree_flatten
   tree_unflatten
   tree_map
